{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
    "import json\n",
    "import random\n",
    "from datasets import load_dataset\n",
    "device = \"cuda\" # the device to load the model onto"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading checkpoint shards: 100%|██████████| 4/4 [00:05<00:00,  1.32s/it]\n",
      "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
     ]
    }
   ],
   "source": [
    "model = AutoModelForCausalLM.from_pretrained(\n",
    "    \"Qwen2-7B-Instruct\",  # path2Qwen2\n",
    "    torch_dtype=\"auto\",\n",
    "    device_map=\"auto\"\n",
    ")\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"Qwen2-7B-Instruct\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "def chat(system, prompt, model=model, tokenizer=tokenizer):\n",
    "\n",
    "    messages = [\n",
    "    {\"role\": \"system\", \"content\": system},\n",
    "    {\"role\": \"user\", \"content\": prompt}\n",
    "        ]\n",
    "    text = tokenizer.apply_chat_template(\n",
    "        messages,\n",
    "        tokenize=False,\n",
    "        add_generation_prompt=True\n",
    "    )\n",
    "    model_inputs = tokenizer([text], return_tensors=\"pt\").to(device)\n",
    "\n",
    "    generated_ids = model.generate(\n",
    "        model_inputs.input_ids,\n",
    "        max_new_tokens=768\n",
    "    )\n",
    "    generated_ids = [\n",
    "        output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)\n",
    "    ]\n",
    "\n",
    "    return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Dataset({\n",
       "    features: ['instruction', 'input', 'output'],\n",
       "    num_rows: 10\n",
       "})"
      ]
     },
     "execution_count": 46,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import json\n",
    "datasets='GAOKAO_math'\n",
    "content = load_dataset('json', data_files=f'Eval/dataset/{datasets}.jsonl',split='train')\n",
    "content"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The question is: 问题是:已知向量 ā = (0, 1), b̄ = (2, x)，若 b̄ ⊥ (b̄ - 4ā)，则 x = (  ), 选项是:A: -2, B: -1, C: 1, D: 2, 您的答案是:\n",
      "The prediction is:  要解决这个问题，我们首先需要理解题目中的条件 \"b̄ ⊥ (b̄ - 4ā)\"，这表示向量 b̄ 与向量 (b̄ - 4ā) 是垂直的。两个向量垂直意味着它们的点积为0。\n",
      "\n",
      "给定向量 ā = (0, 1) 和 b̄ = (2, x)，我们可以计算出 (b̄ - 4ā) = (2 - 4*0, x - 4*1) = (2, x - 4)。\n",
      "\n",
      "两个向量的点积公式为：\\[ \\vec{a} \\cdot \\vec{b} = a_xb_x + a_yb_y \\]\n",
      "\n",
      "所以，对于向量 b̄ 和 (b̄ - 4ā)，我们有：\n",
      "\n",
      "\\[ b̄ \\cdot (b̄ - 4ā) = 2*2 + x*(x-4) = 0 \\]\n",
      "\n",
      "\\[ 4 + x^2 - 4x = 0 \\]\n",
      "\n",
      "整理得：\n",
      "\n",
      "\\[ x^2 - 4x + 4 = 0 \\]\n",
      "\n",
      "这是一个完全平方公式，可以写为：\n",
      "\n",
      "\\[ (x - 2)^2 = 0 \\]\n",
      "\n",
      "因此，解得：\n",
      "\n",
      "\\[ x = 2 \\]\n",
      "\n",
      "所以，正确答案是 D: 2。\n",
      "The answer is:  D\n"
     ]
    }
   ],
   "source": [
    "idx = random.randint(0, len(content))\n",
    "instruction, question, answer = content[idx].values()\n",
    "prompt = '{question}'.format(question=question)\n",
    "response = chat(instruction, prompt)\n",
    "print(\"The question is:\", prompt)\n",
    "print(\"The prediction is: \", response)\n",
    "print(\"The answer is: \", answer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "xhr",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
